function [rules_sep, rules_comb, savings] = summarize_savings_graph_baseline(covariate,wave,cost,winter_prevuse,SMC,fbase)

% This function plots the treatment rules learned using the EWM method
% (a quadrant rule and a cubic rule), specifically for rules based only on
% baseline energy consumption, and summarizes the welfare of each rule.

% Inputs: 
% (1) covariate: 'min', 'max' or 'std' -- minimum of baseline consumption,
% maximum of baseline consumption, or standard deviation of consumption.
% It is plotted on the x-axis against average pre-treatment usage.
% (2) wave: indicates whether the output is wave-specific (=3,6,or 7) or
% for the pooled sample (=0)
% (3) cost: =true if welfare measure is in terms of cost savings using private marginal cost 
% =false if welfare measure is in terms of kWh reduction
% (4) winter_prevuse: =1 if baseline consumption is calculated as the mean
% of consumption in winter months (Jan and Feb); =0 if it is the mean of
% specified pre-treatment periods. Any other value raises an error.
% (5) SMC: =1 use social marginal cost; =0 use retail electricity price
% (only consulted when cost is true)
% (6) fbase: string used to indicate the propensity score and baseline
% months specification. It prefixes the result files read here:
% <fbase>_coef_<quadrant|cubic>_<covariate>_<SMC|PMC|kwh>[_winter]_wave<wave>.mat

% Output: 
% (1) rules_sep: figure with separate subplots for the quadrant and cubic treatment rule
% (2) rules_comb: figure combining the quadrant and cubic treatment rule in one plot
% (3) savings: (optional) 2-row table, row 1 = quadrant rule (rule=1),
% row 2 = cubic rule (rule=2), with columns rule, percent, savings_hh,
% number_hh, total_savings, ci_lb, ci_ub. Callers that request only two
% outputs are unaffected.

%% Resolve input file names
% Both names are built before anything is loaded: the .mat contents are
% unpacked into this workspace with eval below, and a saved field called
% e.g. s_cost, wave or covariate would otherwise alter the second name.
[s_cost, s_winter] = local_file_tags(cost, SMC, winter_prevuse);
file_quadrant = local_coef_file(fbase, 'quadrant', covariate, s_cost, s_winter, wave);
file_cubic = local_coef_file(fbase, 'cubic', covariate, s_cost, s_winter, wave);

% Private copies of the inputs that are still needed after loading, for
% the same reason.
opt_cost = cost;
opt_covariate = covariate;

%% Import quadrant rule results
S_quadrant=load(file_quadrant);

% extract content of S_quadrant into workspace variables (the code below
% reads in_Ghat, vhats_p, vhats_n, nw, gu, n, minW and, for plotting,
% Xu, v_x1, v_x2, min_sign1/2, min_i1/2 -- presumably from this file; confirm)
c_fieldnames = fieldnames(S_quadrant);
for ifield = 1:length(c_fieldnames)
    field_ = c_fieldnames{ifield};
    eval(sprintf('%s=S_quadrant.%s;',field_,field_))
end

in_Ghat_quadrant=in_Ghat;
% NOTE(review): vhats_n is negated here but not for the cubic rule below;
% confirm this matches how each result file stores vhats_n.
vhats_quadrant=nonzeros(max([vhats_p -vhats_n],[],2));

percent_quadrant=sum(nw.*in_Ghat_quadrant)/n;
savings_hh_quadrant=sum(gu.*in_Ghat_quadrant)/n;
if opt_cost
    total_savings_quadrant=sum(gu.*in_Ghat_quadrant)*12;
else
    total_savings_quadrant=sum(gu.*in_Ghat_quadrant)*12/1000;
end
ci_lb_quadrant=(minW-prctile(vhats_quadrant,95))/n;
ci_ub_quadrant=(minW+prctile(vhats_quadrant,95))/n;

% Row captured now, before the cubic file can overwrite n.
row_quadrant=[1, percent_quadrant, savings_hh_quadrant, n, ...
    total_savings_quadrant, ci_lb_quadrant, ci_ub_quadrant];

%% Import cubic rule results
S=load(file_cubic);

% extract content of S into workspace variables; same-named fields from the
% quadrant file are overwritten
c_fieldnames = fieldnames(S);
for ifield = 1:length(c_fieldnames)
    field_ = c_fieldnames{ifield};
    eval(sprintf('%s=S.%s;',field_,field_))
end

% "beta" conflicts with a built-in function, have to do this extra
% assignment so MATLAB resolves it as a variable
beta=S.beta;
in_Ghat_cubic=in_Ghat;
vhats_cubic=nonzeros(max([vhats_p vhats_n],[],2));

percent_cubic=mean(in_Ghat_cubic);
% NaN-ignoring mean (same result as the deprecated nanmean)
savings_hh_cubic=mean(g.*in_Ghat_cubic,'omitnan')*Yscale;
if opt_cost
    total_savings_cubic=savings_hh_cubic*n*12;
else
    total_savings_cubic=savings_hh_cubic*n*12/1000;
end
ci_lb_cubic=(v-prctile(vhats_cubic,95))*Yscale/n;
ci_ub_cubic=(v+prctile(vhats_cubic,95))*Yscale/n;

row_cubic=[2, percent_cubic, savings_hh_cubic, n, ...
    total_savings_cubic, ci_lb_cubic, ci_ub_cubic];

savings=array2table([row_quadrant; row_cubic], 'VariableNames', ...
    {'rule','percent','savings_hh','number_hh','total_savings','ci_lb','ci_ub'});

%% Separate plots for quadrant and cubic rules
% NOTE(review): the 3-entry legends assume both treat/not-treat groups are
% present in the gscatter data.
rules_sep=figure('position',[100,100,1100,350]);

subplot(1,2,1); hold on;
gscatter(Xu(:,1), Xu(:,2), in_Ghat_quadrant, 'br', '..');
axis([min(v_x1),max(v_x1),min(v_x2),max(v_x2)]);
scatter(x1_wave,prevuse_wave,0.1,'MarkerEdgeColor','none','MarkerFaceColor',[0.7 .7 0.7]);
local_decorate_rule_axes(opt_covariate, opt_cost);

subplot(1,2,2); hold on;
gscatter(x1_wave,prevuse_wave, in_Ghat_cubic, 'br', '..');
scatter(x1_wave,prevuse_wave,0.1,'MarkerEdgeColor','none','MarkerFaceColor',[0.7 .7 0.7]);
local_decorate_rule_axes(opt_covariate, opt_cost);
hold off;

%% Plot quadrant and cubic results together 
rules_comb=figure; hold on;
set(rules_comb,'position',[10,10,800,600]);

% quadrant rule: rectangle on the side of each cutoff given by its sign
[x_lo, x_hi] = local_quadrant_span(min_sign1, v_x1, min_i1);
[y_lo, y_hi] = local_quadrant_span(min_sign2, v_x2, min_i2);
x_coord=[x_lo x_lo x_hi x_hi];
y_coord=[y_lo y_hi y_hi y_lo];

% cubic rule: usage cutoff for treatment along a grid of covariate values
line_x1=(min(x1_wave):1:max(x1_wave))';
line_usage=Xscale(1,4)*...
               (-(beta(1)+line_x1*beta(2)./Xscale(1,1)...
               +(line_x1.^2)*beta(3)./Xscale(1,2)...
               +(line_x1.^3)*beta(4)./Xscale(1,3))./beta(5));

if (beta(5)>0) % treatment rule is increasing in pre-treatment consumption
    in_band=(line_usage<=max(prevuse_wave));
    patch_x=[line_x1(in_band); flipud(line_x1(in_band))];
    patch_y=[max(line_usage(in_band),min(prevuse_wave)); max(prevuse_wave)*ones(sum(in_band),1) ];
else           % treatment rule is decreasing in pre-treatment consumption
    in_band=(line_usage>=min(prevuse_wave));
    patch_x=[line_x1(in_band); flipud(line_x1(in_band))];
    patch_y=[min(line_usage(in_band),max(prevuse_wave)); min(prevuse_wave)*ones(sum(in_band),1) ];
end

patch (x_coord, y_coord, 'red', 'LineStyle', '-', 'FaceAlpha', 0.3,'EdgeColor', 'none')
patch (patch_x,patch_y,'blue', 'LineStyle', '-', 'FaceAlpha', 0.45,'EdgeColor', 'none')
xmin = min(min(x_coord),min(patch_x)); xmax= max(max(x_coord),max(patch_x));
ymin = min(min(y_coord),min(patch_y)); ymax= max(max(y_coord),max(patch_y));

% Grey hatched triangle over the half-plane the covariate cannot reach
% (min above the average, or max below the average).
has_na_region = strcmp(opt_covariate,'min') || strcmp(opt_covariate,'max');
if strcmp(opt_covariate,'min')
    xmin = min(xmin,ymin); xmax= max(xmax,ymax);
    xx=[xmin,xmax,xmax]; yy=[xmin,xmin,xmax]; 
    patch (xx,yy,[1 1 1]*0.8,'EdgeColor','None','FaceAlpha',1);
    %hatch
    N_hatch = 20; x2=linspace(xmin,xmax,N_hatch); 
    col_hatch=[1,1,1]*0.5;
    plot([x2;x2],[xmin*ones(1,N_hatch);x2],'linewidth',0.5,'color',col_hatch)
    plot([x2;xmax*ones(1,N_hatch)],[x2;x2],'linewidth',0.5,'color',col_hatch)
elseif strcmp(opt_covariate,'max')
    xmin = min(xmin,ymin); xmax= max(xmax,ymax);
    xx=[xmin,xmax,xmin]; yy=[xmin,xmax,xmax]; 
    patch (xx,yy,[1 1 1]*0.8,'EdgeColor','None','FaceAlpha',1);
    %hatch
    N_hatch = 20; x2=linspace(xmin,xmax,N_hatch); 
    col_hatch=[1,1,1]*0.5;
    plot([x2;x2],[x2;xmax*ones(1,N_hatch)],'linewidth',0.5,'color',col_hatch)
    plot([xmin*ones(1,N_hatch);x2],[x2;x2],'linewidth',0.5,'color',col_hatch)
end

axis([min(v_x1),max(v_x1),min(v_x2),max(v_x2)]);
ylabel('Pre-treatment Usage [kWh/mo]')
% Only label the grey region when it was actually drawn; otherwise legend
% warns about an extra entry.
legend_labels={'Quadrant rule treat','Cubic rule treat','Not applicable'};
if ~has_na_region
    legend_labels=legend_labels(1:2);
end
legend(legend_labels{:})
local_covariate_xlabel(opt_covariate, {'SD of pre-treatment usage [kWh/mo]'});
set(findall(gcf,'-property','FontSize'),'FontSize',24)

hold off;
end

function [s_cost, s_winter] = local_file_tags(cost, SMC, winter_prevuse)
% Return the welfare-measure tag ('SMC', 'PMC' or 'kwh') and the baseline
% tag ('_winter' or '') used in the result file names.
if cost
    if SMC
        s_cost = 'SMC';
    else
        s_cost = 'PMC';
    end
else
    s_cost = 'kwh';
end
switch winter_prevuse
    case 1
        s_winter = '_winter';
    case 0
        s_winter = '';
    otherwise
        error('summarize_savings_graph_baseline:winter_prevuse', ...
            'winter_prevuse must be 0 or 1.');
end
end

function file = local_coef_file(fbase, rule, covariate, s_cost, s_winter, wave)
% Build the path of the saved EWM results for one rule ('quadrant' or 'cubic').
filename_coefs = sprintf('coef_%s_%s_%s%s_wave%1.0f.mat', rule, covariate, s_cost, s_winter, wave);
file = sprintf('%s_%s', fbase, filename_coefs);
end

function [lo, hi] = local_quadrant_span(sign_, v_grid, i_cut)
% Interval of the quadrant rule along one axis: below the cutoff
% v_grid(i_cut) for a negative sign, above it for a positive sign.
if sign_ < 0
    lo = min(v_grid); hi = v_grid(i_cut);
elseif sign_ > 0
    lo = v_grid(i_cut); hi = max(v_grid);
else
    error('summarize_savings_graph_baseline:min_sign', ...
        'Quadrant rule sign must be strictly positive or negative.');
end
end

function local_covariate_xlabel(covariate, std_label)
% Label the current x-axis with the rule covariate; std_label is used for 'std'.
switch char(covariate)
    case 'min'
        xlabel('Minimum of pre-treatment usage [kWh/mo]');
    case 'max'
        xlabel('Maximum of pre-treatment usage [kWh/mo]');
    case 'std'
        xlabel(std_label);
end
end

function local_decorate_rule_axes(covariate, cost)
% Axis labels, title and legend shared by both subplots of the separate-rules figure.
local_covariate_xlabel(covariate, 'Standard deviation of pre-treatment usage [kWh/mo]');
if cost
    title('Baseline-based EWM rule, maximize cost savings')
else
    title('Baseline-based EWM rule, maximize energy conservation')
end
ylabel('Average of pre-treatment usage [kWh/mo]')
legend('\color{blue} Not treat','\color{red} Treat','Population density')
end